import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import copy
import csv

import hyperparameters
from processing_data import casia_wav_route, segment_dataset, read_json
from data_set import make_dataset
from model import Light_SERNet_V1
from optim_utils import ExpLR




############################## Dataset preprocessing ##############################
# Two steps, as described by the notice printed below:
#   1. casia_wav_route() gathers information about the raw dataset into a JSON file
#      under data/ (the notice says this is skipped when the file already exists).
#   2. segment_dataset() uses that JSON to split train/test sets and normalise every
#      audio clip to a fixed length, writing the result back under data/
#      (the notice says this is skipped when already done).

_stars = "*" * 25
print(_stars, " 开始预处理数据集 ", _stars)

_preprocess_notice = """
    简单说明: 1. 统计数据集的信息，在 data 目录下生成 json 文件（文件存在则跳过）
             2. 根据生成的信息，将数据集划分为训练集与测试集，并规范每个音频文件为指定长度，划分好的数据重新存在 data 目录下（已划分则跳过）
    """
print(_preprocess_notice)

casia_wav_route()
_wav_index = read_json("data/casia_wav_route.json")
segment_dataset(_wav_index)

print(_stars, " 数据集预处理完成 ", _stars)

###################################################################################




############################## Setup before training ##############################

# Device: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loading.
# The training loader reshuffles every epoch. DataLoader defaults to shuffle=False,
# which would feed the mini-batches in the dataset's stored order on every epoch.
# The test loader keeps a fixed order: evaluation results do not depend on it.
train_dataset, test_dataset = make_dataset("data/casia_4_segment")
train_loader = DataLoader(train_dataset, batch_size=hyperparameters.BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=hyperparameters.BATCH_SIZE)

# Model, sized by the number of CASIA labels (presumably one output per emotion
# class — confirm against model.Light_SERNet_V1).
net = Light_SERNet_V1(len(hyperparameters.CASIA_LABELS)).to(device)

# Loss function.
criterion = nn.CrossEntropyLoss().to(device)

# Optimizer.
optimizer = optim.Adam(net.parameters(), lr=hyperparameters.LEARNING_RATE)

# Learning-rate schedule.
# ExpLR is a project helper configured with LEARNING_RATE_DECAY_STEP and
# LEARNING_RATE_DECAY_PARAMETERS. The original note described an exponential decay,
# "after epoch 50, adjust every 20 epochs, lr = lr * exp(?)". The actual numbers
# come from `hyperparameters` — TODO confirm the exact rule in optim_utils.ExpLR.
scheduler = ExpLR(optimizer, hyperparameters.LEARNING_RATE_DECAY_STEP, hyperparameters.LEARNING_RATE_DECAY_PARAMETERS)

# Lowest test loss seen so far; starts at +inf so the first epoch always improves on it.
best_val_loss = float("inf")

# Snapshot of the model that achieved best_val_loss (None until the first improvement).
best_net = None

###################################################################################




# Compute accuracy and average loss
def cal_acc(net, data_load, loss, device=None):
    """Evaluate a model and return its accuracy and per-sample average loss.

    inputs:
        net: the model to evaluate. It is switched to eval mode and left there.
        data_load: iterable of ``(x, y)`` batches, e.g. a DataLoader, where ``y``
            holds integer class targets.
        loss: the loss module used by the model. If its ``reduction`` attribute is
            ``"sum"``, batch losses are accumulated as-is. Otherwise each batch
            loss is treated as a batch mean and weighted by the batch size, so the
            result is a true per-sample mean even when the last batch is short.
        device: device each batch is moved to. It defaults to the device of the
            model's first parameter, or the CPU for a parameter-less model.
    output:
        The model accuracy first, then the average loss per sample.
    raises:
        ValueError: if ``data_load`` yields no samples.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    sum_reduction = getattr(loss, "reduction", "mean") == "sum"

    net.eval()
    all_item = 0
    acc_item = 0
    all_loss = 0.0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for x, y in data_load:
            batch_size = y.shape[0]
            x, y = x.to(device), y.to(device)
            y_hat = net(x)
            batch_loss = loss(y_hat, y).item()
            # A mean-reduced loss was averaged over the batch: scale it back to a
            # batch total so dividing by all_item below gives a per-sample mean.
            all_loss += batch_loss if sum_reduction else batch_loss * batch_size
            all_item += batch_size
            acc_item += torch.eq(y, y_hat.argmax(dim=1)).sum().item()

    if all_item == 0:
        raise ValueError("data_load yielded no samples; cannot compute accuracy/loss")
    return acc_item / all_item, all_loss / all_item


# Start training
print("-"*25, " 开始训练 ", "-"*25)
# The csv module requires the file to be opened with newline=""; otherwise each
# row is followed by an extra blank line on Windows.
with open("train_log.csv", "w", newline="") as f:  # per-epoch training log
    w_log = csv.writer(f)
    w_log.writerow(["Epoch", "train_loss", "train_acc", "test_loss", "test_acc", "lr"])

    for epoch in range(hyperparameters.EPOCHS):
        # One optimisation pass over the training set.
        net.train()
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            y_hat = net(x)
            loss = criterion(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Re-evaluate on both splits. cal_acc switches the net to eval mode, which is
        # why net.train() is called again at the top of every epoch.
        train_acc, train_loss = cal_acc(net, train_loader, criterion)
        test_acc, test_loss = cal_acc(net, test_loader, criterion)
        current_lr = scheduler.get_last_lr()[0]

        print(
            "Epoch {:3d}/{:3d} train_loss: {:5.6f} train_acc: {:5.6f} test_loss: {:5.6f} test_acc: {:5.6f} lr: {:5.6f}".format(
                epoch + 1, hyperparameters.EPOCHS, train_loss, train_acc, test_loss, test_acc, current_lr
            )
        )

        w_log.writerow([
            epoch + 1, train_loss, train_acc, test_loss, test_acc, current_lr
        ])

        # Keep a snapshot of the model with the lowest test loss seen so far.
        # NOTE(review): the test split doubles as the validation split, so the
        # metrics of the saved model on that split are optimistically biased.
        if test_loss < best_val_loss:
            best_val_loss = test_loss
            best_net = copy.deepcopy(net)

        # scheduler.step() starts LEARNING_RATE_DECAY_STEP epochs before the
        # configured start point. Presumably ExpLR only decays once every
        # LEARNING_RATE_DECAY_STEP calls, so the first decay lands on the start
        # point — TODO confirm against optim_utils.ExpLR.
        if epoch >= hyperparameters.LEARNING_RATE_DECAY_STRATPOINT - hyperparameters.LEARNING_RATE_DECAY_STEP:
            scheduler.step()  # adjust the learning rate


# Save the best model.
# best_net stays None when EPOCHS is 0 or no epoch's test loss beat +inf (e.g. it
# was always NaN). Fail with an explicit message instead of an AttributeError on None.
if best_net is None:
    raise RuntimeError("no best model was recorded during training; nothing to save")
torch.save(best_net.state_dict(), "Light_SERNet_V1.pth")
